load("//bazel:api.bzl", "lit_tests")

# Unless a target sets its own `visibility`, everything declared in this
# package is visible to //max:consumers (presumably a package_group).
package(
    default_visibility = [
        "//max:consumers",
    ],
)

# NVIDIA tests: every .mojo file under this package, except the top-level
# test_rocshmem.mojo, which has its own AMD-only target (test_rocshmem).
lit_tests(
    name = "test",
    srcs = glob(
        ["**/*.mojo"],
        # Only the top-level file is excluded; this pattern does not match
        # test_rocshmem.mojo in subdirectories.
        exclude = ["test_rocshmem.mojo"],
    ),
    # Only an NVIDIA branch is needed. gpu_constraints below requires
    # //:nvidia_gpu, so under any other GPU config this target is
    # incompatible and skipped, and an AMD branch here would never apply.
    # NOTE(review): "+http_archive+nvshmem_prebuilt" looks like a bzlmod
    # canonical repo name, which is not a stable identifier. The path is
    # relative to the test's working directory (presumably the runfiles
    # root). No explicit `data` dep on the nvshmem archive is declared here,
    # so presumably //max:shmem brings it into runfiles. Confirm all three.
    env = select({
        "//:nvidia_gpu": {"MODULAR_SHMEM_LIB_DIR": "../+http_archive+nvshmem_prebuilt"},
        "//conditions:default": {},
    }),
    # GPU memory requested from the executor for scheduling. Units are
    # defined by the executor config (TODO confirm).
    exec_properties = {
        "test.resources:gpu-memory": "20",
    },
    # Requires an NVIDIA config with 4 GPUs.
    gpu_constraints = [
        "//:has_4_gpus",
        "//:nvidia_gpu",
    ],
    mojo_deps = [
        "//max:layout",
        "//max:shmem",
        "@mojo//:stdlib",
    ],
    tags = [
        "gpu",
    ],
)

# AMD tests: runs only test_rocshmem.mojo (excluded from the "test" glob
# above) against the prebuilt rocSHMEM archive.
lit_tests(
    name = "test_rocshmem",
    srcs = ["test_rocshmem.mojo"],
    # rocSHMEM files made available to the test at runtime (presumably the
    # host-side runtime libraries, per the target name; confirm).
    data = [
        "@rocshmem//:host",
    ],
    # Under //:amd_gpu, points the dynamic loader at the rocSHMEM lib dir and
    # sets OPAL_PREFIX to the archive root.
    # NOTE(review): OPAL_PREFIX is presumably there so the bundled MPI runtime
    # (mpirun/prterun in `tools`) can find its relocated install tree. The
    # paths are relative to the test's working directory (presumably the
    # runfiles root). "+http_archive+rocshmem" looks like the bzlmod
    # canonical name of @rocshmem, which is not a stable identifier.
    # Confirm all three.
    env = select({
        "//:amd_gpu": {
            "LD_LIBRARY_PATH": "../+http_archive+rocshmem/lib",
            "OPAL_PREFIX": "../+http_archive+rocshmem",
        },
        "//conditions:default": {},
    }),
    # GPU memory requested from the executor for scheduling. Units are
    # defined by the executor config (TODO confirm).
    exec_properties = {
        "test.resources:gpu-memory": "4",
    },
    # Compatible only on MI300X or MI355 configs with multiple GPUs. Every
    # other config gets @platforms//:incompatible, so the target is skipped.
    gpu_constraints = select({
        "//:mi300x_gpu": ["//:has_multi_gpu"],
        "//:mi355_gpu": ["//:has_multi_gpu"],
        "//conditions:default": ["@platforms//:incompatible"],
    }),
    mojo_deps = [
        "//max:layout",
        "//max:shmem",
        "@mojo//:stdlib",
    ],
    tags = [
        "gpu",
    ],
    # MPI launchers shipped in the rocSHMEM archive. Presumably used to start
    # the multi-process test; confirm against test_rocshmem.mojo's RUN lines.
    tools = [
        "@rocshmem//:bin/mpirun",
        "@rocshmem//:bin/prterun",
    ],
)
